

%% data and voting- and campaign-stage estimates

load('voting_stage.mat')
load('campaign_stage.mat')

% candidate valence draws
% One standard deviation of the estimated valences xi0 per candidate column,
% each taken over its own row range of xi0; NS normal draws are then scaled by
% those standard deviations (seeded for reproducibility).
rng('default')
NS = 1e+4;
rowsXi = {1:N, N+(1:N), 2*N+(1:N0+N2), 2*N+N0+N2+(1:N0+N1), 3*N+N0+(1:N)};
sigma_Xi = cell2mat(cellfun(@(r) std(xi0(r,:)), rowsXi, 'UniformOutput', false));
Xi0 = sigma_Xi .* randn(NS,5);
% counterfactual scenarios: valence set to -Inf in column 3 (Xi1) or column 4
% (Xi2) -- presumably removes that candidate from the race; confirm in
% surplus1/surplus2
Xi1 = Xi0;
Xi1(:,3) = -Inf;
Xi2 = Xi0;
Xi2(:,4) = -Inf;
% second-stage valence draws: one std per block of N1+N2 rows of xi2S0; both
% second-stage scenarios reuse the very same draws
sigma_Xi_2S = [std(xi2S0(1:N1+N2,:)), std(xi2S0(N1+N2+(1:N1+N2),:))];
Xi1_2S = sigma_Xi_2S .* randn(NS,2);
Xi2_2S = Xi1_2S;


%% coalition's expected surpluses, spending, shares, and gradients for inference

% first-stage-voting surpluses and counterfactual spending, shares, and winning probabilities
% Expected surpluses are averaged along dimension 2 of the surplusX output and
% winning probabilities along dimension 3 of the share array (presumably the
% draw dimension). A candidate column "wins" where its share equals the
% maximum across columns.
winProb = @(S,j) mean(S(:,j,:) == max(S,[],2),3);
tic
[ES0,EQ0,S0,EQU0] = surplus0(gamma,B,D,P,Xi0,[N0,N1,N2],optCS,cmax,1,1);
ES0 = mean(ES0,2);
PW0_PVEM = winProb(S0,3);
PW0_PRI = winProb(S0,4);
[ES1,EQ1,S1,EQU1] = surplus1(gamma,B,D,P,Xi1,[N0,N1,N2],optCS,cmax,1,1);
ES1 = mean(ES1,2);
PW1_PRI = winProb(S1,4);
[ES2,EQ2,S2,EQU2] = surplus2(gamma,B,D,P,Xi2,[N0,N1,N2],optCS,cmax,1,1);
ES2 = mean(ES2,2);
PW2_PVEM = winProb(S2,3);
toc
disp('first stage')
disp('first stage')

% second-stage-voting surpluses and gradients
% The final argument of surplus_2S (1 or 2) pairs with the Xi1_2S / Xi2_2S
% draws -- presumably it selects the coalition scenario; confirm in surplus_2S.
% Surpluses are averaged along dimension 2, gradients along dimension 3.
tic
[ES1_2S,S1_2S,DES1_2S] = surplus_2S(gamma,B,B2S0,D,P,Xi1_2S,[N0,N1,N2],1);
[ES1_2S,DES1_2S] = deal(mean(ES1_2S,2), mean(DES1_2S,3));
[ES2_2S,S2_2S,DES2_2S] = surplus_2S(gamma,B,B2S0,D,P,Xi2_2S,[N0,N1,N2],2);
[ES2_2S,DES2_2S] = deal(mean(ES2_2S,2), mean(DES2_2S,3));
toc
disp('second stage')

% first-stage-voting gradients
%
% One-sided forward differences, evaluated on the first NS/10 valence draws
% only, to reduce computing time:
%     D(:,j) = (f(theta + df*e_j) - f(theta)) / df
% Columns index [gamma, B, B2S0]: gamma(k) -> column k, B(k) -> column
% length(gamma)+k. The B2S0 columns are not filled in this block.
%
% The unperturbed f(theta) is re-evaluated on the SAME NS/10 draws as the
% perturbed one (common random numbers). Differencing against the full-NS
% averages ES0/ES1/ES2/PW0_PVEM/PW0_PRI computed earlier would put the
% simulation gap between an NS/10-draw and an NS-draw average into the
% numerator. Dividing that gap by a step of order rfd swamps the derivative.
rfd = 1e-4;
NSg = floor(NS/10);                  % draws used for every gradient evaluation
Ndist = [N0,N1,N2];
nTheta = length(gamma) + length(B) + length(B2S0);
DES0 = zeros(N,nTheta);
DES1 = zeros(N,nTheta);
DES2 = zeros(N,nTheta);
DPW0_PVEM = zeros(N,nTheta);
DPW0_PRI = zeros(N,nTheta);

% share of draws (dimension 3) in which candidate column j attains the maximum share
pwin = @(S,j) mean(S(:,j,:) == max(S,[],2),3);
% step size: the power of ten nearest rfd * mean(|x|). If x is all zeros the
% scale falls back to 1; otherwise log10(0) = -Inf would give df = 0 and
% Inf/NaN gradients.
stepsize = @(x) 10^round(log10(rfd * (mean(abs(x(:))) + all(x(:) == 0))));

% unperturbed values on the gradient subsample, using the same calls and flags
% as the perturbed evaluations below
tic
[ES0_base,~,S0_base] = surplus0(gamma,B,D,P,Xi0(1:NSg,:),Ndist,optCS,cmax,1,0);
ES0_base = mean(ES0_base,2);
PW0_PVEM_base = pwin(S0_base,3);
PW0_PRI_base = pwin(S0_base,4);
ES1_base = mean(surplus1(gamma,B,D,P,Xi1(1:NSg,:),Ndist,optCS,cmax,0,0),2);
ES2_base = mean(surplus2(gamma,B,D,P,Xi2(1:NSg,:),Ndist,optCS,cmax,0,0),2);
toc
disp('G baseline')

% List of perturbations. Each entry shifts one coordinate of gamma or B by df.
% scen flags which first-stage scenarios are re-evaluated:
%   scen = [surplus0 (plus PW0_PVEM/PW0_PRI), surplus1, surplus2]
newjob = @(onGamma,idx,df,scen) struct('onGamma',onGamma,'idx',idx,'df',df,'scen',scen);
jobs = struct('onGamma',{},'idx',{},'df',{},'scen',{});
allScen = [true,true,true];

% w.r.t. gamma (one common step)
df = stepsize(gamma);
for k = 1:length(gamma)
    jobs(end+1) = newjob(true,k,df,allScen); %#ok<SAGROW>
end
% w.r.t. menu-party fixed effects (one common step per candidate type)
% independent candidates: scenario 0 only
kk = [1,3+KD+1,2*(3+KD)+1,2*(3+KD)+2+KD+1,2*(3+KD)+2*(2+KD)+1];
df = stepsize(B(kk));
for k = kk
    jobs(end+1) = newjob(false,k,df,[true,false,false]); %#ok<SAGROW>
end
% PRI coalition candidate: scenario 1 only
kk = [2,3+KD+2,2*(3+KD)+2+KD+2,2*(3+KD)+2*(2+KD)+2];
df = stepsize(B(kk));
for k = kk
    jobs(end+1) = newjob(false,k,df,[false,true,false]); %#ok<SAGROW>
end
% PVEM coalition candidate: scenario 2 only
kk = [3,3+KD+3,2*(3+KD)+2,2*(3+KD)+2*(2+KD)+3];
df = stepsize(B(kk));
for k = kk
    jobs(end+1) = newjob(false,k,df,[false,false,true]); %#ok<SAGROW>
end
% w.r.t. demand demographics: coefficient k of each menu block, in the order
% MP, NA, PVEM, PRI, PAN, with one common step per k
kk = [3,3+KD+3,2*(3+KD)+2,2*(3+KD)+2+KD+2,2*(3+KD)+2*(2+KD)+3];
for k = 1:KD
    df = stepsize(B(kk + k));
    for m = kk + k
        jobs(end+1) = newjob(false,m,df,allScen); %#ok<SAGROW>
    end
end
% w.r.t. lagged vote share and campaign spending (own step per coefficient)
for k = 1:3
    m = 3*(3+KD)+2*(2+KD)+k;
    jobs(end+1) = newjob(false,m,stepsize(B(m)),allScen); %#ok<SAGROW>
end

% evaluate every perturbation and store the forward differences
for p = 1:numel(jobs)
    tic
    jb = jobs(p);
    gamma_df = gamma;
    B_df = B;
    if jb.onGamma
        gamma_df(jb.idx) = gamma_df(jb.idx) + jb.df;
        col = jb.idx;
    else
        B_df(jb.idx) = B_df(jb.idx) + jb.df;
        col = length(gamma) + jb.idx;
    end
    if jb.scen(1)
        [ES0_df,~,S0_df] = surplus0(gamma_df,B_df,D,P,Xi0(1:NSg,:),Ndist,optCS,cmax,1,0);
        DES0(:,col) = (mean(ES0_df,2) - ES0_base) / jb.df;
        DPW0_PVEM(:,col) = (pwin(S0_df,3) - PW0_PVEM_base) / jb.df;
        DPW0_PRI(:,col) = (pwin(S0_df,4) - PW0_PRI_base) / jb.df;
    end
    if jb.scen(2)
        ES1_df = surplus1(gamma_df,B_df,D,P,Xi1(1:NSg,:),Ndist,optCS,cmax,0,0);
        DES1(:,col) = (mean(ES1_df,2) - ES1_base) / jb.df;
    end
    if jb.scen(3)
        ES2_df = surplus2(gamma_df,B_df,D,P,Xi2(1:NSg,:),Ndist,optCS,cmax,0,0);
        DES2(:,col) = (mean(ES2_df,2) - ES2_base) / jb.df;
    end
    toc
    if jb.onGamma
        disp('G gamma')
    else
        disp('G beta')
    end
    disp(col)
end

% total surpluses and gradients (first-stage plus second-stage contributions)
ES1 = ES1 + ES1_2S;
DES1 = DES1 + DES1_2S;
ES2 = ES2 + ES2_2S;
DES2 = DES2 + DES2_2S;


%%

% Drop every workspace variable except the listed draws, surpluses, shares,
% winning probabilities, rfd and gradients, then save what remains.
keepList = {'NS','sigma_Xi','Xi0','Xi1','Xi2','sigma_Xi_2S','Xi1_2S','Xi2_2S', ...
    'ES0','EQ0','S0','EQU0','PW0_PVEM','PW0_PRI', ...
    'ES1','EQ1','S1','EQU1','PW1_PRI','ES2','EQ2','S2','EQU2','PW2_PVEM', ...
    'ES1_2S','S1_2S','DES1_2S','ES2_2S','S2_2S','DES2_2S', ...
    'rfd','DES0','DES1','DES2','DPW0_PVEM','DPW0_PRI'};
clearvars('-except',keepList{:})

save('coalition_stage_p1.mat')



